{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zsEdjC-kxfoC"
   },
   "source": [
    "# Text drift detection on IMDB movie reviews\n",
    "\n",
    "## Method\n",
    "\n",
    "We detect drift on text data using both the [Maximum Mean Discrepancy](https://docs.seldon.io/projects/alibi-detect/en/stable/cd/methods/mmddrift.html) and [Kolmogorov-Smirnov (K-S)](https://docs.seldon.io/projects/alibi-detect/en/stable/cd/methods/ksdrift.html) detectors. In this example notebook we will focus on detecting covariate shift $\\Delta p(x)$ as detecting predicted label distribution drift does not differ from other modalities (check [K-S](https://docs.seldon.io/projects/alibi-detect/en/stable/examples/cd_ks_cifar10.html#BBSDs) and [MMD](https://docs.seldon.io/projects/alibi-detect/en/stable/examples/cd_mmd_cifar10.html#BBSDs) drift on CIFAR-10).\n",
    "\n",
    "It becomes however a little bit more involved when we want to pick up input data drift $\\Delta p(x)$. When we deal with tabular or image data, we can either directly apply the two sample hypothesis test on the input or do the test after a preprocessing step with for instance a randomly initialized encoder as proposed in [Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift](https://arxiv.org/abs/1810.11953) (they call it an Untrained AutoEncoder or *UAE*). It is not as straightforward when dealing with text, both in string or tokenized format as they don't directly represent the semantics of the input.\n",
    "\n",
    "As a result, we extract (contextual) embeddings for the text and detect drift on those. This procedure has a significant impact on the type of drift we detect. Strictly speaking we are not detecting $\\Delta p(x)$ anymore since the whole training procedure (objective function, training data etc) for the (pre)trained embeddings has an impact on the embeddings we extract.\n",
    "\n",
    "The library contains functionality to leverage pre-trained embeddings from [HuggingFace's transformer package](https://github.com/huggingface/transformers) but also allows you to easily use your own embeddings of choice. Both options are illustrated with examples in this notebook.\n",
    "\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "Note\n",
    "\n",
    "As is done in this example, it is recommended to pass text data to detectors as a list of strings (`List[str]`). This allows for seamless integration with HuggingFace's transformers library.\n",
    "\n",
    "One exception to the above is when custom embeddings are used. Here, it is important to ensure that the data is passed to the custom embedding model in a compatible format. In [the final example](#Train-embeddings-from-scratch), a `preprocess_batch_fn` is defined in order to convert `list`'s to the `np.ndarray`'s expected by the custom TensorFlow embedding.\n",
    "    \n",
    "</div>\n",
    "\n",
    "## Backend\n",
    "\n",
    "The method works with both the **PyTorch** and **TensorFlow** frameworks for the statistical tests and preprocessing steps. Alibi Detect does however not install PyTorch for you. \n",
    "Check the [PyTorch docs](https://pytorch.org/) how to do this.\n",
    "\n",
    "## Dataset\n",
    "\n",
    "Binary sentiment classification [dataset](https://ai.stanford.edu/~amaas/data/sentiment/) containing $25,000$ movie reviews for training and $25,000$ for testing. Install the `nlp` library to fetch the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "r6ttO1nVxfoG"
   },
   "outputs": [],
   "source": [
    "!pip install nlp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qG5BL45LxfoG"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"TF_USE_LEGACY_KERAS\"] = \"1\"\n",
    "\n",
    "import nlp\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from transformers import AutoTokenizer\n",
    "from alibi_detect.cd import KSDrift, MMDDrift\n",
    "from alibi_detect.saving import save_detector, load_detector"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xhB9-toUxfoH"
   },
   "source": [
    "### Load tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "ewiBdavQxfoH",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model_name = 'bert-base-cased'\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ai9hjQaOxfoI"
   },
   "source": [
    "### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "AeINSkzrxfoI"
   },
   "outputs": [],
   "source": [
    "def load_dataset(dataset: str, split: str = 'test'):\n",
    "    data = nlp.load_dataset(dataset)\n",
    "    X, y = [], []\n",
    "    for x in data[split]:\n",
    "        X.append(x['text'])\n",
    "        y.append(x['label'])\n",
    "    X = np.array(X)\n",
    "    y = np.array(y)\n",
    "    return X, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "51MqauScxfoJ",
    "outputId": "505cc4cf-fe91-4f7d-d009-aff690195f04",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(25000,) (25000,)\n"
     ]
    }
   ],
   "source": [
    "X, y = load_dataset('imdb', split='train')\n",
    "print(X.shape, y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MMN3B9k-xfoK"
   },
   "source": [
    "Let's take a look at respectively a negative and positive review:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "GwMZIdXzxfoK",
    "outputId": "fcd16581-0b16-479b-b375-d4bb8122bb64"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Negative\n",
      "This is one of the dumbest films, I've ever seen. It rips off nearly ever type of thriller and manages to make a mess of them all.<br /><br />There's not a single good line or character in the whole mess. If there was a plot, it was an afterthought and as far as acting goes, there's nothing good to say so Ill say nothing. I honestly cant understand how this type of nonsense gets produced and actually released, does somebody somewhere not at some stage think, 'Oh my god this really is a load of shite' and call it a day. Its crap like this that has people downloading illegally, the trailer looks like a completely different film, at least if you have download it, you haven't wasted your time or money Don't waste your time, this is painful.\n"
     ]
    }
   ],
   "source": [
    "labels = ['Negative', 'Positive']\n",
    "print(labels[y[-1]])\n",
    "print(X[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "TCojQsnCxfoK",
    "outputId": "66fee3b1-be31-4520-a9d3-f12685d3c9be"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Positive\n",
      "Brilliant over-acting by Lesley Ann Warren. Best dramatic hobo lady I have ever seen, and love scenes in clothes warehouse are second to none. The corn on face is a classic, as good as anything in Blazing Saddles. The take on lawyers is also superb. After being accused of being a turncoat, selling out his boss, and being dishonest the lawyer of Pepto Bolt shrugs indifferently \"I'm a lawyer\" he says. Three funny words. Jeffrey Tambor, a favorite from the later Larry Sanders show, is fantastic here too as a mad millionaire who wants to crush the ghetto. His character is more malevolent than usual. The hospital scene, and the scene where the homeless invade a demolition site, are all-time classics. Look for the legs scene and the two big diggers fighting (one bleeds). This movie gets better each time I see it (which is quite often).\n"
     ]
    }
   ],
   "source": [
    "print(labels[y[2]])\n",
    "print(X[2])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hjVFuRnKxfoL"
   },
   "source": [
    "We split the original test set in a reference dataset and a dataset which should not be rejected under the *H0* of the statistical test. We also create imbalanced datasets and inject selected words in the reference set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "YjV-qwNuxfoL"
   },
   "outputs": [],
   "source": [
    "def random_sample(X: np.ndarray, y: np.ndarray, proba_zero: float, n: int):\n",
    "    if len(y.shape) == 1:\n",
    "        idx_0 = np.where(y == 0)[0]\n",
    "        idx_1 = np.where(y == 1)[0]\n",
    "    else:\n",
    "        idx_0 = np.where(y[:, 0] == 1)[0]\n",
    "        idx_1 = np.where(y[:, 1] == 1)[0]\n",
    "    n_0, n_1 = int(n * proba_zero), int(n * (1 - proba_zero))\n",
    "    idx_0_out = np.random.choice(idx_0, n_0, replace=False)\n",
    "    idx_1_out = np.random.choice(idx_1, n_1, replace=False)\n",
    "    X_out = np.concatenate([X[idx_0_out], X[idx_1_out]])\n",
    "    y_out = np.concatenate([y[idx_0_out], y[idx_1_out]])\n",
    "    return X_out.tolist(), y_out.tolist()\n",
    "\n",
    "\n",
    "def padding_last(x: np.ndarray, seq_len: int) -> np.ndarray:\n",
    "    try:  # try not to replace padding token\n",
    "        last_token = np.where(x == 0)[0][0]\n",
    "    except:  # no padding\n",
    "        last_token = seq_len - 1\n",
    "    return 1, last_token\n",
    "\n",
    "\n",
    "def padding_first(x: np.ndarray, seq_len: int) -> np.ndarray:\n",
    "    try:  # try not to replace padding token\n",
    "        first_token = np.where(x == 0)[0][-1] + 2\n",
    "    except:  # no padding\n",
    "        first_token = 0\n",
    "    return first_token, seq_len - 1\n",
    "\n",
    "\n",
    "def inject_word(token: int, X: np.ndarray, perc_chg: float, padding: str = 'last'):\n",
    "    seq_len = X.shape[1]\n",
    "    n_chg = int(perc_chg * .01 * seq_len)\n",
    "    X_cp = X.copy()\n",
    "    for _ in range(X.shape[0]):\n",
    "        if padding == 'last':\n",
    "            first_token, last_token = padding_last(X_cp[_, :], seq_len)\n",
    "        else:\n",
    "            first_token, last_token = padding_first(X_cp[_, :], seq_len)\n",
    "        if last_token <= n_chg:\n",
    "            choice_len = seq_len\n",
    "        else:\n",
    "            choice_len = last_token\n",
    "        idx = np.random.choice(np.arange(first_token, choice_len), n_chg, replace=False)\n",
    "        X_cp[_, idx] = token\n",
    "    return X_cp.tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YkeJNC6UxfoL"
   },
   "source": [
    "Reference, *H0* and imbalanced data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "6aZFd5fgxfoL"
   },
   "outputs": [],
   "source": [
    "# proba_zero = fraction with label 0 (=negative sentiment)\n",
    "n_sample = 1000\n",
    "X_ref = random_sample(X, y, proba_zero=.5, n=n_sample)[0]\n",
    "X_h0 = random_sample(X, y, proba_zero=.5, n=n_sample)[0]\n",
    "n_imb = [.1, .9]\n",
    "X_imb = {_: random_sample(X, y, proba_zero=_, n=n_sample)[0] for _ in n_imb}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cRyDI1e4xfoM"
   },
   "source": [
    "Inject words in reference data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KCV-JUnnxfoM"
   },
   "outputs": [],
   "source": [
    "words = ['fantastic', 'good', 'bad', 'horrible']\n",
    "perc_chg = [1., 5.]  # % of tokens to change in an instance\n",
    "\n",
    "words_tf = tokenizer(words)['input_ids']\n",
    "words_tf = [token[1:-1][0] for token in words_tf]\n",
    "max_len = 100\n",
    "tokens = tokenizer(X_ref, pad_to_max_length=True, \n",
    "                   max_length=max_len, return_tensors='tf')\n",
    "X_word = {}\n",
    "for i, w in enumerate(words_tf):\n",
    "    X_word[words[i]] = {}\n",
    "    for p in perc_chg:\n",
    "        x = inject_word(w, tokens['input_ids'].numpy(), p)\n",
    "        dec = tokenizer.batch_decode(x, **dict(skip_special_tokens=True))\n",
    "        X_word[words[i]][p] = dec"
   ]
  },
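  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the perturbation strength concrete: `inject_word` overwrites a fixed number of randomly chosen token positions per review. Writing $r$ for `perc_chg` and $L$ for the sequence length (symbols introduced here for illustration), the number of replaced tokens is\n",
    "\n",
    "$$n_{\\text{chg}} = \\lfloor r \\cdot 0.01 \\cdot L \\rfloor$$\n",
    "\n",
    "With `max_len = 100`, a perturbation of $1\\%$ replaces a single token per review and $5\\%$ replaces five, so the $1\\%$ setting is a very subtle change."
   ]
  },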
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "nf6SjGzKAPvr",
    "outputId": "ea0c653b-0d28-4901-f9ca-5e26ff1e23e5"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(1000, 100), dtype=int32, numpy=\n",
       "array([[  101,  1188,  1794, ...,     0,     0,     0],\n",
       "       [  101,  1556,  5122, ...,  1307,  1800,   102],\n",
       "       [  101,  3406,  4720, ...,  5674,  2723,   102],\n",
       "       ...,\n",
       "       [  101,  2082,  1122, ...,  1641,   107,   102],\n",
       "       [  101,  1124,   118, ...,  1155,  1104,   102],\n",
       "       [  101,  1249, 24017, ...,     0,     0,     0]], dtype=int32)>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokens['input_ids']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pMJ-JH6WxfoM"
   },
   "source": [
    "## Preprocessing\n",
    "\n",
    "First we need to specify the type of embedding we want to extract from the BERT model. We can extract embeddings from the ...\n",
    "\n",
    "- **pooler_output**: Last layer hidden-state of the first token of the sequence (classification token; CLS) further processed by a Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence prediction (classification) objective during pre-training. **Note**: this output is usually not a good summary of the semantic content of the input, you’re often better with averaging or pooling the sequence of hidden-states for the whole input sequence.\n",
    "\n",
    "- **last_hidden_state**: Sequence of hidden states at the output of the last layer of the model, averaged over the tokens.\n",
    "\n",
    "- **hidden_state**: Hidden states of the model at the output of each layer, averaged over the tokens.\n",
    "\n",
    "- **hidden_state_cls**: See *hidden_state* but use the CLS token output.\n",
    "\n",
    "If *hidden_state* or *hidden_state_cls* is used as embedding type, you also need to pass the layer numbers used to extract the embedding from. As an example we extract embeddings from the last 8 hidden states."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-kd4nbcXxfoM",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from alibi_detect.models.tensorflow import TransformerEmbedding\n",
    "\n",
    "emb_type = 'hidden_state'\n",
    "n_layers = 8\n",
    "layers = [-_ for _ in range(1, n_layers + 1)]\n",
    "\n",
    "embedding = TransformerEmbedding(model_name, emb_type, layers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tV3bnMwGxfoM"
   },
   "source": [
    "Let's check what an embedding looks like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wgJQucxSxfoN",
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "tokens = tokenizer(list(X[:5]), pad_to_max_length=True, \n",
    "                   max_length=max_len, return_tensors='tf')\n",
    "x_emb = embedding(tokens)\n",
    "print(x_emb.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6IcVvbaYxfoN"
   },
   "source": [
    "So the BERT model's embedding space used by the drift detector consists of a $768$-dimensional vector for each instance. We will therefore first apply a dimensionality reduction step with an Untrained AutoEncoder (*UAE*) before conducting the statistical hypothesis test. We use the embedding model as the input for the UAE which then projects the embedding on a lower dimensional space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "T6rQpEr5xfoN"
   },
   "outputs": [],
   "source": [
    "tf.random.set_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "id": "yvKnc6ZxxfoN"
   },
   "outputs": [],
   "source": [
    "from alibi_detect.cd.tensorflow import UAE\n",
    "\n",
    "enc_dim = 32\n",
    "shape = (x_emb.shape[1],)\n",
    "\n",
    "uae = UAE(input_layer=embedding, shape=shape, enc_dim=enc_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mopByQSzxfoN"
   },
   "source": [
    "Let's test this again:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "aAwVIRRexfoN",
    "outputId": "9b1ece6b-ce9c-4142-da87-feb9bba97707"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5, 32)\n"
     ]
    }
   ],
   "source": [
    "emb_uae = uae(tokens)\n",
    "print(emb_uae.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Vx73zIu3xfoN"
   },
   "source": [
    "## K-S detector\n",
    "\n",
    "### Initialize\n",
    "\n",
    "We proceed to initialize the drift detector. From here on the detector works the same as for other modalities such as images. Please check the [images](https://docs.seldon.io/projects/alibi-detect/en/stable/examples/cd_ks_cifar10.html) example or the [K-S detector documentation](https://docs.seldon.io/projects/alibi-detect/en/stable/cd/methods/ksdrift.html) for more information about each of the possible parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4OolVqEUxfoN",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "from alibi_detect.cd.tensorflow import preprocess_drift\n",
    "\n",
    "# define preprocessing function\n",
    "preprocess_fn = partial(preprocess_drift, model=uae, tokenizer=tokenizer, \n",
    "                        max_len=max_len, batch_size=32)\n",
    "\n",
    "# initialize detector\n",
    "cd = KSDrift(X_ref, p_val=.05, preprocess_fn=preprocess_fn, input_shape=(max_len,))\n",
    "\n",
    "# we can also save/load an initialised detector\n",
    "filepath = 'my_path'  # change to directory where detector is saved\n",
    "save_detector(cd, filepath)\n",
    "cd = load_detector(filepath)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lpZbGa-2xfoO"
   },
   "source": [
    "### Detect drift\n",
    "\n",
    "Let’s first check if drift occurs on a similar sample from the training set as the reference data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "E1E56gnvxfoO",
    "outputId": "af5dedcf-9be2-4c9c-8c4c-b44196af7d42"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drift? No!\n",
      "p-value: [0.31356168 0.18111965 0.60991895 0.43243074 0.6852314  0.722555\n",
      " 0.28769323 0.18111965 0.50035924 0.9134755  0.40047103 0.79439443\n",
      " 0.79439443 0.722555   0.5726548  0.1640792  0.9540582  0.60991895\n",
      " 0.5726548  0.5726548  0.31356168 0.40047103 0.6852314  0.34099194\n",
      " 0.5726548  0.07762147 0.79439443 0.09710453 0.5726548  0.79439443\n",
      " 0.7590978  0.26338065]\n"
     ]
    }
   ],
   "source": [
    "preds_h0 = cd.predict(X_h0)\n",
    "labels = ['No!', 'Yes!']\n",
    "print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))\n",
    "print('p-value: {}'.format(preds_h0['data']['p_val']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Q1Ub4WpYxfoO"
   },
   "source": [
    "Detect drift on imbalanced and perturbed datasets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "2_8Dw8QBxfoO",
    "outputId": "c765e7f1-0ff3-43c9-d674-c573d137ec56"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "% negative sentiment 10.0\n",
      "Drift? Yes!\n",
      "p-value: [4.32430744e-01 4.00471032e-01 5.46463318e-02 7.76214674e-02\n",
      " 1.08282514e-01 1.12110768e-02 6.91903234e-02 2.82894098e-03\n",
      " 8.59294355e-01 6.47557259e-01 1.33834302e-01 7.94394433e-01\n",
      " 4.28151786e-02 2.87693232e-01 6.09918952e-01 1.33834302e-01\n",
      " 2.40603596e-01 9.71045271e-02 7.76214674e-02 9.35580969e-01\n",
      " 2.87693232e-01 2.92505771e-02 4.00471032e-01 6.09918952e-01\n",
      " 2.87693232e-01 5.06567594e-04 1.64079204e-01 6.09918952e-01\n",
      " 1.33834302e-01 2.19330013e-01 7.94394433e-01 2.56591532e-02]\n",
      "\n",
      "% negative sentiment 90.0\n",
      "Drift? Yes!\n",
      "p-value: [7.36993998e-02 1.37563676e-01 5.86588383e-02 5.07961273e-01\n",
      " 8.37696046e-02 8.80799629e-03 1.23670578e-01 1.76981179e-04\n",
      " 3.21924835e-01 1.20594716e-02 8.43600273e-01 4.08206195e-01\n",
      " 1.69703156e-01 5.79056978e-01 6.32701874e-01 4.48510349e-02\n",
      " 5.07465303e-01 6.64306164e-04 5.23085408e-02 3.78374875e-01\n",
      " 6.65342569e-01 4.06090707e-01 6.21288121e-01 5.85612692e-02\n",
      " 5.87646782e-01 7.55570829e-03 8.99188042e-01 1.18489005e-02\n",
      " 6.68586135e-01 1.01421457e-02 7.97733963e-02 1.73885196e-01]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for k, v in X_imb.items():\n",
    "    preds = cd.predict(v)\n",
    "    print('% negative sentiment {}'.format(k * 100))\n",
    "    print('Drift? {}'.format(labels[preds['data']['is_drift']]))\n",
    "    print('p-value: {}'.format(preds['data']['p_val']))\n",
    "    print('')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "K7GYX5_BxfoO",
    "outputId": "98d4a86a-e741-4709-a95e-a72e45d95cc8",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Word: fantastic -- % perturbed: 1.0\n",
      "Drift? No!\n",
      "p-value: [0.8879386  0.01711409 0.2406036  0.9134755  0.21933001 0.04281518\n",
      " 0.03778438 0.28769323 0.3699725  0.996931   0.8879386  0.43243074\n",
      " 0.01121108 0.6852314  0.99870795 0.996931   0.93558097 0.99365413\n",
      " 0.02246371 0.60991895 0.8879386  0.34099194 0.09710453 0.8879386\n",
      " 0.1338343  0.06155144 0.85929435 0.99365413 0.07762147 0.07762147\n",
      " 0.9882611  0.85929435]\n",
      "\n",
      "Word: fantastic -- % perturbed: 5.0\n",
      "Drift? Yes!\n",
      "p-value: [1.29345525e-02 1.69780876e-14 1.52437299e-11 5.72654784e-01\n",
      " 1.85489473e-08 1.88342838e-17 6.14975981e-09 4.28151786e-02\n",
      " 5.62237052e-13 2.13202584e-05 4.28151786e-02 1.97469308e-09\n",
      " 0.00000000e+00 1.48931602e-02 9.68870163e-01 1.29345525e-02\n",
      " 2.63380647e-01 1.08282514e-01 1.04535818e-26 4.28151786e-02\n",
      " 2.13202584e-05 3.47411038e-14 1.09291570e-20 1.08282514e-01\n",
      " 5.68982140e-18 1.69780876e-14 1.64079204e-01 4.00471032e-01\n",
      " 3.12689441e-34 3.89208371e-27 2.86525619e-06 1.71956726e-05]\n",
      "\n",
      "Word: good -- % perturbed: 1.0\n",
      "Drift? Yes!\n",
      "p-value: [3.40991944e-01 9.80161786e-01 1.08282514e-01 9.98707950e-01\n",
      " 1.48338065e-01 9.35580969e-01 7.59097815e-01 9.88261104e-01\n",
      " 8.87938619e-01 6.47557259e-01 9.68870163e-01 7.94394433e-01\n",
      " 8.69054198e-02 9.99999642e-01 9.96931016e-01 5.72654784e-01\n",
      " 9.99870896e-01 4.32430744e-01 9.99870896e-01 2.92505771e-02\n",
      " 9.13475513e-01 9.13475513e-01 4.65766221e-01 9.35580969e-01\n",
      " 8.87938619e-01 9.98707950e-01 9.80161786e-01 9.99972701e-01\n",
      " 7.59097815e-01 1.34916729e-04 9.96931016e-01 9.68870163e-01]\n",
      "\n",
      "Word: good -- % perturbed: 5.0\n",
      "Drift? Yes!\n",
      "p-value: [6.1319246e-16 8.5929435e-01 8.4248814e-24 5.3605431e-01 6.1410643e-10\n",
      " 1.9951835e-01 2.9080641e-04 3.6997250e-01 2.4072561e-04 3.3837957e-10\n",
      " 9.5405817e-01 8.6666952e-04 5.2673625e-28 1.4893160e-02 9.7104527e-02\n",
      " 5.3955968e-11 1.6407920e-01 6.1410643e-10 7.2255498e-01 2.5362303e-18\n",
      " 7.9439443e-01 1.7943768e-06 1.5330249e-07 2.0378644e-03 1.4563050e-03\n",
      " 2.1933001e-01 1.9626908e-02 6.4755726e-01 1.4790693e-09 0.0000000e+00\n",
      " 1.9626908e-02 3.1356168e-01]\n",
      "\n",
      "Word: bad -- % perturbed: 1.0\n",
      "Drift? No!\n",
      "p-value: [0.8879386  0.21933001 0.12050407 0.9540582  0.9134755  0.9540582\n",
      " 0.99870795 0.9540582  0.7590978  0.40047103 0.9801618  0.7590978\n",
      " 0.02925058 0.996931   0.9995433  0.79439443 0.26338065 0.04281518\n",
      " 0.93558097 0.14833806 0.50035924 0.82795686 0.18111965 0.43243074\n",
      " 0.99365413 0.9882611  0.9801618  0.99870795 0.96887016 0.10828251\n",
      " 0.07762147 0.9882611 ]\n",
      "\n",
      "Word: bad -- % perturbed: 5.0\n",
      "Drift? Yes!\n",
      "p-value: [7.04859247e-08 5.78442112e-12 7.08821891e-21 1.33834302e-01\n",
      " 7.13247118e-06 3.69972497e-01 9.68870163e-01 1.81119651e-01\n",
      " 2.13202584e-05 3.47411038e-14 5.00359237e-01 1.97830971e-07\n",
      " 9.82534992e-39 1.03241683e-03 1.96269080e-02 2.92505771e-02\n",
      " 8.76041099e-07 8.49670826e-18 1.08282514e-01 3.38379574e-10\n",
      " 8.07501343e-25 5.37760343e-07 2.79573150e-17 2.40344345e-03\n",
      " 1.99518353e-01 7.59097815e-01 8.69054198e-02 3.32311448e-03\n",
      " 2.15581372e-12 3.95873130e-15 1.95523170e-16 5.72654784e-01]\n",
      "\n",
      "Word: horrible -- % perturbed: 1.0\n",
      "Drift? Yes!\n",
      "p-value: [2.63380647e-01 9.98707950e-01 9.98707950e-01 9.88261104e-01\n",
      " 6.47557259e-01 8.59294355e-01 9.96931016e-01 9.13475513e-01\n",
      " 3.50604125e-04 9.99870896e-01 9.99870896e-01 6.09918952e-01\n",
      " 1.33834302e-01 9.80161786e-01 9.35580969e-01 9.88261104e-01\n",
      " 9.71045271e-02 4.00471032e-01 6.85231388e-01 1.81119651e-01\n",
      " 4.65766221e-01 9.80161786e-01 8.69054198e-02 9.96931016e-01\n",
      " 9.99870896e-01 6.91903234e-02 9.80161786e-01 9.99972701e-01\n",
      " 9.93654132e-01 5.32228360e-03 1.20504074e-01 7.22554982e-01]\n",
      "\n",
      "Word: horrible -- % perturbed: 5.0\n",
      "Drift? Yes!\n",
      "p-value: [1.6978088e-14 8.8793862e-01 2.8769323e-01 5.7265478e-01 1.3491673e-04\n",
      " 1.7114086e-02 4.3243074e-01 1.1211077e-02 8.5801831e-33 3.5060412e-04\n",
      " 8.6905420e-02 6.1497598e-09 1.4797455e-32 1.3383430e-01 1.7244401e-03\n",
      " 2.6338065e-01 1.4117470e-08 3.5060412e-04 5.7140245e-15 4.9547091e-14\n",
      " 5.9822431e-37 8.9143086e-06 8.4967083e-18 3.1356168e-01 8.7604110e-07\n",
      " 3.9584363e-20 1.4833806e-01 1.7244401e-03 1.1053569e-12 0.0000000e+00\n",
      " 1.3007273e-15 2.9250577e-02]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for w, probas in X_word.items():\n",
    "    for p, v in probas.items():\n",
    "        preds = cd.predict(v)\n",
    "        print('Word: {} -- % perturbed: {}'.format(w, p))\n",
    "        print('Drift? {}'.format(labels[preds['data']['is_drift']]))\n",
    "        print('p-value: {}'.format(preds['data']['p_val']))\n",
    "        print('')"
   ]
  },
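  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The K-S detector reports one p-value per feature, here one for each of the $32$ UAE dimensions, while `is_drift` is a single decision over all of them. Assuming the detector's default Bonferroni correction applies (no `correction` argument is passed in this notebook), drift is flagged when the smallest p-value falls below the significance level $\\alpha$ divided by the number of features $d$:\n",
    "\n",
    "$$\\min_{j} \\, p_j < \\frac{\\alpha}{d} = \\frac{0.05}{32} \\approx 1.56 \\times 10^{-3}$$\n",
    "\n",
    "This is consistent with the outputs above. For *good* at $1\\%$ perturbation, a single feature has a p-value of about $1.35 \\times 10^{-4}$, which is below the corrected threshold, so drift is flagged even though most other p-values are close to $1$. For *fantastic* at $1\\%$, the smallest p-value is about $0.011$, which is below $0.05$ but above the corrected threshold, so no drift is flagged."
   ]
  },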
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nWwYweooxfoO"
   },
   "source": [
    "## MMD TensorFlow detector\n",
    "\n",
    "### Initialize\n",
    "\n",
    "Again check the [images](https://docs.seldon.io/projects/alibi-detect/en/stable/examples/cd_mmd_cifar10.html) example or the [MMD detector documentation](https://docs.seldon.io/projects/alibi-detect/en/stable/cd/methods/mmddrift.html) for more information about each of the possible parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "id": "APKk7FO3xfoO"
   },
   "outputs": [],
   "source": [
    "cd = MMDDrift(X_ref, p_val=.05, preprocess_fn=preprocess_fn, \n",
    "              n_permutations=100, input_shape=(max_len,))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hm8DtTsnxfoP"
   },
   "source": [
    "### Detect drift\n",
    "\n",
    "*H0*:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "jskCFE1QxfoP",
    "outputId": "5ce043b0-9fdf-4634-cbe4-a041651e5124"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drift? No!\n",
      "p-value: 0.6\n"
     ]
    }
   ],
   "source": [
    "preds_h0 = cd.predict(X_h0)\n",
    "labels = ['No!', 'Yes!']\n",
    "print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))\n",
    "print('p-value: {}'.format(preds_h0['data']['p_val']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5IaVVZJ8xfoP"
   },
   "source": [
    "Imbalanced data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ukf96vjWxfoP",
    "outputId": "2393abdc-c043-45e2-e3f7-16eafa64deb0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "% negative sentiment 10.0\n",
      "Drift? Yes!\n",
      "p-value: 0.01\n",
      "\n",
      "% negative sentiment 90.0\n",
      "Drift? Yes!\n",
      "p-value: 0.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for k, v in X_imb.items():\n",
    "    preds = cd.predict(v)\n",
    "    print('% negative sentiment {}'.format(k * 100))\n",
    "    print('Drift? {}'.format(labels[preds['data']['is_drift']]))\n",
    "    print('p-value: {}'.format(preds['data']['p_val']))\n",
    "    print('')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1utu6w7oxfoP"
   },
   "source": [
    "Perturbed data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "sP9amuYBxfoP",
    "outputId": "9cab23b5-522f-4942-91a6-177bffde0729",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Word: fantastic -- % perturbed: 1.0\n",
      "Drift? No!\n",
      "p-value: 0.09\n",
      "\n",
      "Word: fantastic -- % perturbed: 5.0\n",
      "Drift? Yes!\n",
      "p-value: 0.0\n",
      "\n",
      "Word: good -- % perturbed: 1.0\n",
      "Drift? No!\n",
      "p-value: 0.71\n",
      "\n",
      "Word: good -- % perturbed: 5.0\n",
      "Drift? Yes!\n",
      "p-value: 0.0\n",
      "\n",
      "Word: bad -- % perturbed: 1.0\n",
      "Drift? No!\n",
      "p-value: 0.38\n",
      "\n",
      "Word: bad -- % perturbed: 5.0\n",
      "Drift? Yes!\n",
      "p-value: 0.0\n",
      "\n",
      "Word: horrible -- % perturbed: 1.0\n",
      "Drift? No!\n",
      "p-value: 0.18\n",
      "\n",
      "Word: horrible -- % perturbed: 5.0\n",
      "Drift? Yes!\n",
      "p-value: 0.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for w, probas in X_word.items():\n",
    "    for p, v in probas.items():\n",
    "        preds = cd.predict(v)\n",
    "        print('Word: {} -- % perturbed: {}'.format(w, p))\n",
    "        print('Drift? {}'.format(labels[preds['data']['is_drift']]))\n",
    "        print('p-value: {}'.format(preds['data']['p_val']))\n",
    "        print('')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4CjTZKNKxfoP"
   },
   "source": [
    "## MMD PyTorch detector\n",
    "\n",
    "### Initialize\n",
    "\n",
    "We can run the same detector with the *PyTorch* backend for both the preprocessing step and the MMD implementation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "oUdfEom7xfoP",
    "outputId": "66e60675-7a96-491b-e085-6e6a897c5d91"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# set random seed and device\n",
    "seed = 0\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed(seed)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K0ha7VawxfoP",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from alibi_detect.cd.pytorch import preprocess_drift\n",
    "from alibi_detect.models.pytorch import TransformerEmbedding\n",
    "from alibi_detect.cd.pytorch import UAE\n",
    "\n",
    "# Embedding model\n",
    "embedding_pt = TransformerEmbedding(model_name, emb_type, layers)\n",
    "\n",
    "# PyTorch untrained autoencoder\n",
    "uae = UAE(input_layer=embedding_pt, shape=shape, enc_dim=enc_dim)\n",
    "model = uae.to(device).eval()\n",
    "\n",
    "# define preprocessing function\n",
    "preprocess_fn = partial(preprocess_drift, model=model, tokenizer=tokenizer,\n",
    "                        max_len=max_len, batch_size=32, device=device)\n",
    "\n",
    "# initialise drift detector\n",
    "cd = MMDDrift(X_ref, backend='pytorch', p_val=.05, preprocess_fn=preprocess_fn,\n",
    "              n_permutations=100, input_shape=(max_len,))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Q9D2TeJRxfoQ"
   },
   "source": [
    "### Detect drift\n",
    "\n",
    "*H0*:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "kieWKHL4xfoQ",
    "outputId": "f7d47763-6c02-4fee-f750-b5548a736410"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drift? No!\n",
      "p-value: 0.49000000953674316\n"
     ]
    }
   ],
   "source": [
    "preds_h0 = cd.predict(X_h0)\n",
    "labels = ['No!', 'Yes!']\n",
    "print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))\n",
    "print('p-value: {}'.format(preds_h0['data']['p_val']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DZ_pUMKRxfoQ"
   },
   "source": [
    "Imbalanced data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ZznEh7WSxfoQ",
    "outputId": "13cf9b46-a17b-4fd6-a26c-cefad396d262"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "% negative sentiment 10.0\n",
      "Drift? Yes!\n",
      "p-value: 0.0\n",
      "\n",
      "% negative sentiment 90.0\n",
      "Drift? Yes!\n",
      "p-value: 0.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for k, v in X_imb.items():\n",
    "    preds = cd.predict(v)\n",
    "    print('% negative sentiment {}'.format(k * 100))\n",
    "    print('Drift? {}'.format(labels[preds['data']['is_drift']]))\n",
    "    print('p-value: {}'.format(preds['data']['p_val']))\n",
    "    print('')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "s-zKhsRexfoQ"
   },
   "source": [
    "Perturbed data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "tqccOJCSxfoQ",
    "outputId": "77edb585-c88e-4293-ee7a-76daee888c6a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Word: fantastic -- % perturbed: 1.0\n",
      "Drift? Yes!\n",
      "p-value: 0.0\n",
      "\n",
      "Word: fantastic -- % perturbed: 5.0\n",
      "Drift? Yes!\n",
      "p-value: 0.0\n",
      "\n",
      "Word: good -- % perturbed: 1.0\n",
      "Drift? No!\n",
      "p-value: 0.10000000149011612\n",
      "\n",
      "Word: good -- % perturbed: 5.0\n",
      "Drift? Yes!\n",
      "p-value: 0.0\n",
      "\n",
      "Word: bad -- % perturbed: 1.0\n",
      "Drift? Yes!\n",
      "p-value: 0.0\n",
      "\n",
      "Word: bad -- % perturbed: 5.0\n",
      "Drift? Yes!\n",
      "p-value: 0.0\n",
      "\n",
      "Word: horrible -- % perturbed: 1.0\n",
      "Drift? No!\n",
      "p-value: 0.05999999865889549\n",
      "\n",
      "Word: horrible -- % perturbed: 5.0\n",
      "Drift? Yes!\n",
      "p-value: 0.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for w, probas in X_word.items():\n",
    "    for p, v in probas.items():\n",
    "        preds = cd.predict(v)\n",
    "        print('Word: {} -- % perturbed: {}'.format(w, p))\n",
    "        print('Drift? {}'.format(labels[preds['data']['is_drift']]))\n",
    "        print('p-value: {}'.format(preds['data']['p_val']))\n",
    "        print('')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BBNO_KDKxfoQ"
   },
   "source": [
    "## Train embeddings from scratch\n",
    "\n",
    "So far we have used pre-trained embeddings from a BERT model. We can, however, also use embeddings from a model trained from scratch. First we define and train a simple classification model in *TensorFlow*, consisting of an embedding layer followed by an LSTM layer.\n",
    "\n",
    "### Load data and train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "id": "V1r8BuPBxfoQ"
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.datasets import imdb, reuters\n",
    "from tensorflow.keras.layers import Dense, Embedding, Input, LSTM\n",
    "from tensorflow.keras.preprocessing import sequence\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "\n",
    "INDEX_FROM = 3\n",
    "NUM_WORDS = 10000\n",
    "\n",
    "\n",
    "def print_sentence(tokenized_sentence: np.ndarray, id2w: dict):\n",
    "    print(' '.join(id2w[token] for token in tokenized_sentence))\n",
    "    print('')\n",
    "    print(tokenized_sentence)\n",
    "\n",
    "\n",
    "def mapping_word_id(data):\n",
    "    w2id = data.get_word_index()\n",
    "    w2id = {k: (v + INDEX_FROM) for k, v in w2id.items()}\n",
    "    w2id[\"<PAD>\"] = 0\n",
    "    w2id[\"<START>\"] = 1\n",
    "    w2id[\"<UNK>\"] = 2\n",
    "    w2id[\"<UNUSED>\"] = 3\n",
    "    id2w = {v: k for k, v in w2id.items()}\n",
    "    return w2id, id2w\n",
    "\n",
    "\n",
    "def get_dataset(dataset: str = 'imdb', max_len: int = 100):\n",
    "    if dataset == 'imdb':\n",
    "        data = imdb\n",
    "    elif dataset == 'reuters':\n",
    "        data = reuters\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "\n",
    "    w2id, id2w = mapping_word_id(data)\n",
    "\n",
    "    (X_train, y_train), (X_test, y_test) = data.load_data(\n",
    "        num_words=NUM_WORDS, index_from=INDEX_FROM)\n",
    "    X_train = sequence.pad_sequences(X_train, maxlen=max_len)\n",
    "    X_test = sequence.pad_sequences(X_test, maxlen=max_len)\n",
    "    y_train, y_test = to_categorical(y_train), to_categorical(y_test)\n",
    "\n",
    "    return (X_train, y_train), (X_test, y_test), (w2id, id2w)\n",
    "\n",
    "\n",
    "def imdb_model(X: np.ndarray, num_words: int = 100, emb_dim: int = 128,\n",
    "               lstm_dim: int = 128, output_dim: int = 2) -> tf.keras.Model:\n",
    "    X = np.array(X)\n",
    "    inputs = Input(shape=(X.shape[1:]), dtype=tf.float32)\n",
    "    x = Embedding(num_words, emb_dim)(inputs)\n",
    "    x = LSTM(lstm_dim, dropout=.5)(x)\n",
    "    outputs = Dense(output_dim, activation=tf.nn.softmax)(x)\n",
    "    model = tf.keras.Model(inputs=inputs, outputs=outputs)\n",
    "    model.compile(\n",
    "        loss='categorical_crossentropy',\n",
    "        optimizer='adam',\n",
    "        metrics=['accuracy']\n",
    "    )\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mm4Xlq8axfoQ"
   },
   "source": [
    "Load and tokenize data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "id": "hKQGbwt-xfoR",
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "(X_train, y_train), (X_test, y_test), (word2token, token2word) = \\\n",
    "    get_dataset(dataset='imdb', max_len=max_len)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vYMCcVqIxfoR"
   },
   "source": [
    "Let's check out an instance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "BcbldijMxfoR",
    "outputId": "8a0864e3-ab47-48b8-b161-5c0c7a5d6fba"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cry at a film it must have been good and this definitely was also <UNK> to the two little boy's that played the <UNK> of norman and paul they were just brilliant children are often left out of the <UNK> list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all\n",
      "\n",
      "[1415   33    6   22   12  215   28   77   52    5   14  407   16   82\n",
      "    2    8    4  107  117 5952   15  256    4    2    7 3766    5  723\n",
      "   36   71   43  530  476   26  400  317   46    7    4    2 1029   13\n",
      "  104   88    4  381   15  297   98   32 2071   56   26  141    6  194\n",
      " 7486   18    4  226   22   21  134  476   26  480    5  144   30 5535\n",
      "   18   51   36   28  224   92   25  104    4  226   65   16   38 1334\n",
      "   88   12   16  283    5   16 4472  113  103   32   15   16 5345   19\n",
      "  178   32]\n"
     ]
    }
   ],
   "source": [
    "print_sentence(X_train[0], token2word)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Y381SpEdxfoR"
   },
   "source": [
    "Define and train a simple model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Qec5Rv9axfoR",
    "outputId": "26b29ea3-4224-4da9-9383-341486417262"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2\n",
      "782/782 [==============================] - 17s 17ms/step - loss: 0.4314 - accuracy: 0.7988 - val_loss: 0.3481 - val_accuracy: 0.8474\n",
      "Epoch 2/2\n",
      "782/782 [==============================] - 14s 18ms/step - loss: 0.2707 - accuracy: 0.8908 - val_loss: 0.3858 - val_accuracy: 0.8451\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7fb07edfef50>"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = imdb_model(X=X_train, num_words=NUM_WORDS, emb_dim=256, lstm_dim=128, output_dim=2)\n",
    "model.fit(X_train, y_train, batch_size=32, epochs=2,\n",
    "          shuffle=True, validation_data=(X_test, y_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8XXdfhN_xfoR"
   },
   "source": [
    "Extract the embedding layer from the trained model and combine it with the UAE preprocessing step:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "RhkU0K8OxfoR",
    "outputId": "f03ec951-d108-48d4-de6e-a603d9d44a10"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5, 100, 256)\n"
     ]
    }
   ],
   "source": [
    "embedding = tf.keras.Model(inputs=model.inputs, outputs=model.layers[1].output)\n",
    "x_emb = embedding(X_train[:5])\n",
    "print(x_emb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "id": "c6oZYWLNxfoR"
   },
   "outputs": [],
   "source": [
    "tf.random.set_seed(0)\n",
    "\n",
    "shape = tuple(x_emb.shape[1:])\n",
    "uae = UAE(input_layer=embedding, shape=shape, enc_dim=enc_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_9dv3scexfoR"
   },
   "source": [
    "Again, create reference, *H0* and perturbed datasets. Also test against the *Reuters* news topic classification dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "id": "6QuvJjxxxfoR"
   },
   "outputs": [],
   "source": [
    "X_ref, y_ref = random_sample(X_test, y_test, proba_zero=.5, n=n_sample)\n",
    "X_h0, y_h0 = random_sample(X_test, y_test, proba_zero=.5, n=n_sample)\n",
    "tokens = [word2token[w] for w in words]\n",
    "X_word = {}\n",
    "for i, t in enumerate(tokens):\n",
    "    X_word[words[i]] = {}\n",
    "    for p in perc_chg:\n",
    "        X_word[words[i]][p] = inject_word(t, np.array(X_ref), p, padding='first')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "id": "DnZ45KNLxfoS",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# load and tokenize Reuters dataset\n",
    "(X_reut, y_reut), (w2t_reut, t2w_reut) = \\\n",
    "    get_dataset(dataset='reuters', max_len=max_len)[1:]\n",
    "\n",
    "# sample random instances\n",
    "idx = np.random.choice(X_reut.shape[0], n_sample, replace=False)\n",
    "X_ood = X_reut[idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LGW7I9DWxfoS"
   },
   "source": [
    "### Initialize detector and detect drift"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "EuWYGUAbxfoS",
    "outputId": "bb070dc7-c9b4-4b5c-8d80-36c8a7867577"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Input shape could not be inferred. If alibi_detect.models.tensorflow.embedding.TransformerEmbedding is used as preprocessing step, a saved detector cannot be reinitialized.\n"
     ]
    }
   ],
   "source": [
    "from alibi_detect.cd.tensorflow import preprocess_drift\n",
    "\n",
    "# define preprocess_batch_fn to convert a list of tokenized instances to an np.ndarray to be processed by the model\n",
    "def convert_list(X: list):\n",
    "    return np.array(X)\n",
    "\n",
    "# define preprocessing function\n",
    "preprocess_fn = partial(preprocess_drift, model=uae, batch_size=128, preprocess_batch_fn=convert_list)\n",
    "\n",
    "# initialize detector\n",
    "cd = KSDrift(X_ref, p_val=.05, preprocess_fn=preprocess_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VzdN0DN6xfoS"
   },
   "source": [
    "*H0*:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ZxPNEB1FxfoS",
    "outputId": "fbd11ff2-d91b-4633-877c-c669abef3c93"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drift? No!\n",
      "p-value: [0.18111965 0.50035924 0.5360543  0.722555   0.2406036  0.02925058\n",
      " 0.43243074 0.12050407 0.722555   0.60991895 0.19951835 0.60991895\n",
      " 0.50035924 0.79439443 0.722555   0.64755726 0.40047103 0.34099194\n",
      " 0.1338343  0.10828251 0.64755726 0.9995433  0.9540582  0.9134755\n",
      " 0.40047103 0.1640792  0.40047103 0.64755726 0.9134755  0.7590978\n",
      " 0.5726548  0.722555  ]\n"
     ]
    }
   ],
   "source": [
    "preds_h0 = cd.predict(X_h0)\n",
    "labels = ['No!', 'Yes!']\n",
    "print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))\n",
    "print('p-value: {}'.format(preds_h0['data']['p_val']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ELV6nSTnxfoS"
   },
   "source": [
    "Perturbed data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Xnu69wQdxfoS",
    "outputId": "47bc4eb8-ebfa-45d2-d72a-815f9c5a2d02",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Word: fantastic -- % perturbed: 1.0\n",
      "Drift? No!\n",
      "p-value: [0.9998709  0.7590978  0.99870795 0.9995433  0.9801618  0.9134755\n",
      " 0.82795686 0.99870795 0.9882611  0.8879386  0.9801618  0.79439443\n",
      " 0.85929435 0.96887016 0.9134755  0.996931   0.5726548  0.93558097\n",
      " 0.9882611  0.99870795 0.93558097 0.96887016 0.85929435 0.9882611\n",
      " 0.93558097 0.996931   0.996931   0.96887016 0.9882611  0.96887016\n",
      " 0.8879386  0.996931  ]\n",
      "\n",
      "Word: fantastic -- % perturbed: 5.0\n",
      "Drift? No!\n",
      "p-value: [0.85929435 0.06155144 0.9540582  0.79439443 0.43243074 0.6852314\n",
      " 0.722555   0.9134755  0.28769323 0.996931   0.60991895 0.19951835\n",
      " 0.43243074 0.64755726 0.722555   0.8879386  0.18111965 0.18111965\n",
      " 0.43243074 0.14833806 0.50035924 0.43243074 0.01489316 0.01121108\n",
      " 0.722555   0.46576622 0.07762147 0.8879386  0.05464633 0.10828251\n",
      " 0.03327804 0.9801618 ]\n",
      "\n",
      "Word: good -- % perturbed: 1.0\n",
      "Drift? No!\n",
      "p-value: [0.99365413 0.8879386  0.99870795 0.9801618  0.99870795 0.99870795\n",
      " 0.9134755  0.93558097 0.8879386  0.9995433  0.93558097 0.996931\n",
      " 0.99999607 0.9995433  0.99870795 0.9801618  0.99870795 0.9801618\n",
      " 0.8879386  0.996931   0.9134755  0.996931   0.7590978  0.99365413\n",
      " 0.9540582  0.99870795 0.99870795 0.9998709  0.9801618  0.64755726\n",
      " 0.9999727  0.8879386 ]\n",
      "\n",
      "Word: good -- % perturbed: 5.0\n",
      "Drift? No!\n",
      "p-value: [0.9882611  0.6852314  0.79439443 0.60991895 0.28769323 0.3699725\n",
      " 0.28769323 0.6852314  0.79439443 0.31356168 0.99870795 0.85929435\n",
      " 0.34099194 0.34099194 0.8879386  0.996931   0.96887016 0.96887016\n",
      " 0.9540582  0.722555   0.19951835 0.9995433  0.3699725  0.722555\n",
      " 0.1338343  0.9134755  0.5360543  0.26338065 0.85929435 0.2406036\n",
      " 0.31356168 0.6852314 ]\n",
      "\n",
      "Word: bad -- % perturbed: 1.0\n",
      "Drift? No!\n",
      "p-value: [0.93558097 0.996931   0.85929435 0.9540582  0.50035924 0.64755726\n",
      " 0.82795686 0.85929435 0.82795686 0.9882611  0.82795686 0.9540582\n",
      " 0.21933001 0.96887016 0.93558097 0.99870795 0.79439443 0.722555\n",
      " 0.93558097 0.93558097 0.64755726 0.99365413 0.5726548  0.9998709\n",
      " 0.93558097 0.96887016 0.9995433  0.99365413 0.7590978  0.93558097\n",
      " 0.9882611  0.9134755 ]\n",
      "\n",
      "Word: bad -- % perturbed: 5.0\n",
      "Drift? Yes!\n",
      "p-value: [4.00471032e-01 8.27956855e-01 2.87693232e-01 6.47557259e-01\n",
      " 3.89581337e-03 1.03241683e-03 3.40991944e-01 7.59097815e-01\n",
      " 2.82894098e-03 5.46463318e-02 1.20504074e-01 2.63380647e-01\n",
      " 1.11190266e-05 5.46463318e-02 4.65766221e-01 7.94394433e-01\n",
      " 9.69783217e-03 3.69972497e-01 9.35580969e-01 1.71140861e-02\n",
      " 6.91903234e-02 7.94394433e-01 9.07998619e-05 4.00471032e-01\n",
      " 8.27956855e-01 7.59097815e-01 1.64079204e-01 4.84188050e-02\n",
      " 1.71140861e-02 6.85231388e-01 5.46463318e-02 5.72654784e-01]\n",
      "\n",
      "Word: horrible -- % perturbed: 1.0\n",
      "Drift? No!\n",
      "p-value: [0.996931   0.9801618  0.96887016 0.79439443 0.79439443 0.5726548\n",
      " 0.82795686 0.996931   0.43243074 0.93558097 0.79439443 0.82795686\n",
      " 0.06919032 0.3699725  0.96887016 0.9540582  0.5360543  0.6852314\n",
      " 0.60991895 0.79439443 0.9540582  0.9801618  0.40047103 0.5726548\n",
      " 0.82795686 0.8879386  0.9540582  0.9134755  0.99365413 0.60991895\n",
      " 0.82795686 0.79439443]\n",
      "\n",
      "Word: horrible -- % perturbed: 5.0\n",
      "Drift? Yes!\n",
      "p-value: [4.00471032e-01 1.48931602e-02 4.84188050e-02 1.96269080e-02\n",
      " 1.12110768e-02 1.48931602e-02 4.00471032e-01 5.72654784e-01\n",
      " 1.45630504e-03 1.96269080e-02 7.59097815e-01 1.72444014e-03\n",
      " 1.30072730e-15 1.79437677e-06 2.63380647e-01 6.47557259e-01\n",
      " 1.11478073e-06 1.99518353e-01 1.20504074e-01 4.55808453e-03\n",
      " 7.21312594e-03 2.40603596e-01 2.24637091e-02 4.28151786e-02\n",
      " 4.28151786e-02 7.22554982e-01 1.08282514e-01 9.07998619e-05\n",
      " 5.36054313e-01 9.71045271e-02 1.64079204e-01 3.40991944e-01]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for w, probas in X_word.items():\n",
    "    for p, v in probas.items():\n",
    "        preds = cd.predict(v)\n",
    "        print('Word: {} -- % perturbed: {}'.format(w, p))\n",
    "        print('Drift? {}'.format(labels[preds['data']['is_drift']]))\n",
    "        print('p-value: {}'.format(preds['data']['p_val']))\n",
    "        print('')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ocKUZUZLxfoS"
   },
   "source": [
    "The detector is not as sensitive as the Transformer-based K-S drift detector. The embeddings trained from scratch were only trained on a small dataset, using a simple model with a cross-entropy loss function, for 2 epochs. The pre-trained BERT model, on the other hand, captures the semantics of the data better.\n",
    "\n",
    "Sample from the Reuters dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "-ABzIgu4xfoS",
    "outputId": "090f7532-f86a-4e3d-d7dc-aa0248b0a674"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drift? Yes!\n",
      "p-value: [7.22554982e-01 1.07232365e-08 3.69972497e-01 9.54058170e-01\n",
      " 7.22554982e-01 4.84188050e-02 9.69783217e-03 1.71956726e-05\n",
      " 8.87938619e-01 4.01514189e-05 2.54783203e-07 1.22740539e-03\n",
      " 4.21853358e-04 3.49877549e-09 5.46463318e-02 1.79437677e-06\n",
      " 6.91903234e-02 4.20066499e-07 3.50604125e-04 2.87693232e-01\n",
      " 1.69780876e-14 1.69780876e-14 3.40991944e-01 2.53623026e-18\n",
      " 2.26972293e-06 3.18301190e-08 2.40344345e-03 5.32228360e-03\n",
      " 2.40725611e-04 2.56591532e-02 3.27475419e-07 5.69539361e-06]\n"
     ]
    }
   ],
   "source": [
    "preds_ood = cd.predict(X_ood)\n",
    "labels = ['No!', 'Yes!']\n",
    "print('Drift? {}'.format(labels[preds_ood['data']['is_drift']]))\n",
    "print('p-value: {}'.format(preds_ood['data']['p_val']))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "cd_text_imdb.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
